-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[mlir] Generalize OneShotModuleBufferize to operate on any Operation #148327
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[mlir] Generalize OneShotModuleBufferize to operate on any Operation #148327
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-sparse Author: Evan Liu (Evanyl) ChangesThere was a commit to change Patch is 32.04 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/148327.diff 19 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h
similarity index 64%
rename from mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
rename to mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h
index 2cf801dd1d951..32b76269e2c03 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h
@@ -1,4 +1,4 @@
-//===- OneShotModuleBufferize.h - Bufferization across Func. Boundaries ---===//
+//===- OneShotRootBufferize.h - Bufferization across Func. Boundaries ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,15 +6,15 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H
-#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H
+#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTROOTBUFFERIZE_H
+#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTROOTBUFFERIZE_H
namespace llvm {
struct LogicalResult;
} // namespace llvm
namespace mlir {
-class ModuleOp;
+class Operation;
namespace bufferization {
struct BufferizationStatistics;
@@ -22,11 +22,11 @@ class OneShotAnalysisState;
struct OneShotBufferizationOptions;
class BufferizationState;
-/// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in
+/// Analyze `rootOp` and its nested ops. Bufferization decisions are stored in
/// `state`.
llvm::LogicalResult
-analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
- BufferizationStatistics *statistics = nullptr);
+analyzeRootOp(Operation *rootOp, OneShotAnalysisState &state,
+ BufferizationStatistics *statistics = nullptr);
/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
///
@@ -38,23 +38,23 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
/// is not empty. The FuncOps it contains were not analyzed. Buffer copies
/// will be inserted only to these FuncOps.
llvm::LogicalResult
-bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options,
- BufferizationState &state,
- BufferizationStatistics *statistics = nullptr);
+bufferizeRootOp(Operation *rootOp, const OneShotBufferizationOptions &options,
+ BufferizationState &state,
+ BufferizationStatistics *statistics = nullptr);
-/// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
-void removeBufferizationAttributesInModule(ModuleOp moduleOp);
+/// Remove bufferization attributes on every FuncOp arguments in the RootOp.
+void removeBufferizationAttributesInRoot(Operation *rootOp);
-/// Run One-Shot Module Bufferization on the given module. Performs a simple
+/// Run One-Shot Root Bufferization on the given root op. Performs a simple
/// function call analysis to determine which function arguments are
/// inplaceable. Then analyzes and bufferizes FuncOps one-by-one with One-Shot
/// Bufferize.
-llvm::LogicalResult runOneShotModuleBufferize(
- ModuleOp moduleOp,
+llvm::LogicalResult runOneShotRootBufferize(
+ Operation *rootOp,
const bufferization::OneShotBufferizationOptions &options,
BufferizationState &state, BufferizationStatistics *statistics = nullptr);
} // namespace bufferization
} // namespace mlir
-#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H
+#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTROOTBUFFERIZE_H
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index db1eb20512033..5a52daf6c7698 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -10,7 +10,7 @@
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
-#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -92,8 +92,8 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
if (options.bufferizeFunctionBoundaries) {
if (!moduleOp)
return emitSilenceableError() << "expected module target";
- if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options,
- bufferizationState)))
+ if (failed(bufferization::runOneShotRootBufferize(moduleOp, options,
+ bufferizationState)))
return emitSilenceableError() << "bufferization failed";
} else {
if (failed(bufferization::runOneShotBufferize(target, options,
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 246555dc8c699..baef091eeebd1 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -12,7 +12,7 @@
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
-#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Diagnostics.h"
@@ -163,8 +163,7 @@ struct OneShotBufferizePass
BufferizationStatistics statistics;
ModuleOp moduleOp = getOperation();
if (opt.bufferizeFunctionBoundaries) {
- if (failed(
- runOneShotModuleBufferize(moduleOp, opt, state, &statistics))) {
+ if (failed(runOneShotRootBufferize(moduleOp, opt, state, &statistics))) {
signalPassFailure();
return;
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
index 7c38621be1bb5..fa310b95df4bd 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
@@ -11,7 +11,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
FuncBufferizableOpInterfaceImpl.cpp
LowerDeallocations.cpp
OneShotAnalysis.cpp
- OneShotModuleBufferize.cpp
+ OneShotRootBufferize.cpp
OwnershipBasedBufferDeallocation.cpp
TensorCopyInsertion.cpp
OptimizeAllocationLiveness.cpp
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index b7db2e847a335..ee1a9178a9d0d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -11,7 +11,7 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
-#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dominance.h"
@@ -209,7 +209,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(RewriterBase &rewriter,
OneShotAnalysisState state(op, options);
if (moduleOp) {
// Module analysis takes into account function boundaries.
- if (failed(analyzeModuleOp(moduleOp, state)))
+ if (failed(analyzeRootOp(moduleOp, state)))
return failure();
} else {
// Regular One-Shot Bufferize ignores func.func block arguments, func.call,
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotRootBufferize.cpp
similarity index 85%
rename from mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
rename to mlir/lib/Dialect/Bufferization/Transforms/OneShotRootBufferize.cpp
index d1d106220a38c..a7865050e6e38 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotRootBufferize.cpp
@@ -1,4 +1,5 @@
-//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===//
+//===- OneShotRootBufferize.cpp - Bufferization across Func. Boundaries
+//----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,12 +7,12 @@
//
//===----------------------------------------------------------------------===//
//
-// Module Bufferization is an extension of One-Shot Bufferize that
+// Root Bufferization is an extension of One-Shot Bufferize that
// bufferizes function boundaries. It provides `BufferizableOpInterface`
// implementations for FuncOp, CallOp and ReturnOp.
//
-// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`.
-// This function analyzes the given module and determines the order of analysis
+// Root Bufferization is run via `runOneShotRootBufferize(RootOp, ...)`.
+// This function analyzes the given op and determines the order of analysis
// and bufferization: Functions that are called are processed before their
// respective callers.
//
@@ -24,7 +25,7 @@
// * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
// read/written.
//
-// Module Bufferization implements the following calling convention.
+// Root Bufferization implements the following calling convention.
//
// * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may always
// be written to in-place.
@@ -57,7 +58,7 @@
// TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked
// as "not reading" and/or "not writing".
-#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -299,7 +300,7 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
llvm::IsaPred<TensorType>);
}
-/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
+/// Store all functions of the `rootOp` in `orderedFuncOps`, sorted by
/// callee-caller order (i.e., callees without callers first). Store all
/// remaining functions (i.e., the ones that call each other recursively) in
/// `remainingFuncOps`. Does not traverse nested symbol tables.
@@ -309,7 +310,7 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
/// Return `failure()` if we are unable to retrieve the called FuncOp from
/// any func::CallOp.
static LogicalResult getFuncOpsOrderedByCalls(
- ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
+ Operation *rootOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap,
SymbolTableCollection &symbolTables) {
// For each FuncOp, the set of functions called by it (i.e. the union of
@@ -317,26 +318,29 @@ static LogicalResult getFuncOpsOrderedByCalls(
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
// For each FuncOp, the number of func::CallOp it contains.
DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
-
- for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
- // Collect function calls and populate the caller map.
- numberCallOpsContainedInFuncOp[funcOp] = 0;
- WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
- func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables);
- assert(calledFunction && "could not retrieved called func::FuncOp");
- // If the called function does not have any tensors in its signature, then
- // it is not necessary to bufferize the callee before the caller.
- if (!hasTensorSignature(calledFunction))
- return WalkResult::skip();
-
- callerMap[calledFunction].insert(callOp);
- if (calledBy[calledFunction].insert(funcOp).second) {
- numberCallOpsContainedInFuncOp[funcOp]++;
+ for (mlir::Region ®ion : rootOp->getRegions()) {
+ for (mlir::Block &block : region.getBlocks()) {
+ for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
+ // Collect function calls and populate the caller map.
+ numberCallOpsContainedInFuncOp[funcOp] = 0;
+ WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
+ func::FuncOp calledFunction = getCalledFunction(callOp);
+ assert(calledFunction && "could not retrieved called func::FuncOp");
+ // If the called function does not have any tensors in its signature,
+ // then it is not necessary to bufferize the callee before the caller.
+ if (!hasTensorSignature(calledFunction))
+ return WalkResult::skip();
+
+ callerMap[calledFunction].insert(callOp);
+ if (calledBy[calledFunction].insert(funcOp).second) {
+ numberCallOpsContainedInFuncOp[funcOp]++;
+ }
+ return WalkResult::advance();
+ });
+ if (res.wasInterrupted())
+ return failure();
}
- return WalkResult::advance();
- });
- if (res.wasInterrupted())
- return failure();
+ }
}
// Iteratively remove function operations that do not call any of the
@@ -447,9 +451,9 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
}
LogicalResult
-mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
- OneShotAnalysisState &state,
- BufferizationStatistics *statistics) {
+mlir::bufferization::analyzeRootOp(Operation *rootOp,
+ OneShotAnalysisState &state,
+ BufferizationStatistics *statistics) {
assert(state.getOptions().bufferizeFunctionBoundaries &&
"expected that function boundary bufferization is activated");
FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);
@@ -465,9 +469,8 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
// A mapping of FuncOps to their callers.
FuncCallerMap callerMap;
- if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
- remainingFuncOps, callerMap,
- funcState.symbolTables)))
+ if (failed(getFuncOpsOrderedByCalls(rootOp, orderedFuncOps, remainingFuncOps,
+ callerMap, funcState.symbolTables)))
return failure();
// Analyze functions in order. Starting with functions that are not calling
@@ -511,20 +514,24 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
return success();
}
-void mlir::bufferization::removeBufferizationAttributesInModule(
- ModuleOp moduleOp) {
- for (auto op : moduleOp.getOps<func::FuncOp>()) {
- for (BlockArgument bbArg : op.getArguments())
- removeBufferizationAttributes(bbArg);
+void mlir::bufferization::removeBufferizationAttributesInRoot(
+ Operation *rootOp) {
+ for (mlir::Region ®ion : rootOp->getRegions()) {
+ for (mlir::Block &block : region.getBlocks()) {
+ for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
+ for (BlockArgument bbArg : funcOp.getArguments())
+ removeBufferizationAttributes(bbArg);
+ }
+ }
}
}
-LogicalResult mlir::bufferization::bufferizeModuleOp(
- ModuleOp moduleOp, const OneShotBufferizationOptions &options,
+LogicalResult mlir::bufferization::bufferizeRootOp(
+ Operation *rootOp, const OneShotBufferizationOptions &options,
BufferizationState &state, BufferizationStatistics *statistics) {
assert(options.bufferizeFunctionBoundaries &&
"expected that function boundary bufferization is activated");
- IRRewriter rewriter(moduleOp.getContext());
+ IRRewriter rewriter(rootOp->getContext());
// A list of non-circular functions in the order in which they are analyzed
// and bufferized.
@@ -542,9 +549,8 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
// accurate buffer types for function return values. Functions that call
// each other recursively are bufferized in an unspecified order at the end.
// We may use unnecessarily "complex" (in terms of layout map) buffer types.
- if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
- remainingFuncOps, callerMap,
- state.getSymbolTables())))
+ if (failed(getFuncOpsOrderedByCalls(rootOp, orderedFuncOps, remainingFuncOps,
+ callerMap, state.getSymbolTables())))
return failure();
llvm::append_range(orderedFuncOps, remainingFuncOps);
@@ -571,22 +577,27 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
}
// Bufferize all other ops.
- for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
- // Functions were already bufferized.
- if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
- continue;
- if (failed(bufferizeOp(&op, options, state, statistics)))
- return failure();
+ for (mlir::Region ®ion : rootOp->getRegions()) {
+ for (mlir::Block &block : region.getBlocks()) {
+ for (mlir::Operation &op :
+ llvm::make_early_inc_range(block.getOperations())) {
+ // Functions were already bufferized.
+ if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
+ continue;
+ if (failed(bufferizeOp(&op, options, state, statistics)))
+ return failure();
+ }
+ }
}
// Post-pass cleanup of function argument attributes.
- removeBufferizationAttributesInModule(moduleOp);
+ removeBufferizationAttributesInRoot(rootOp);
return success();
}
-LogicalResult mlir::bufferization::runOneShotModuleBufferize(
- ModuleOp moduleOp, const OneShotBufferizationOptions &options,
+LogicalResult mlir::bufferization::runOneShotRootBufferize(
+ Operation *rootOp, const OneShotBufferizationOptions &options,
BufferizationState &state, BufferizationStatistics *statistics) {
assert(options.bufferizeFunctionBoundaries &&
"expected that function boundary bufferization is activated");
@@ -594,7 +605,7 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
"invalid combination of bufferization flags");
if (!options.copyBeforeWrite) {
if (options.noAnalysisFuncFilter.empty()) {
- if (failed(insertTensorCopies(moduleOp, options, state, statistics)))
+ if (failed(insertTensorCopies(rootOp, options, state, statistics)))
return failure();
} else {
// FuncOps whose names are specified in options.noAnalysisFuncFilter will
@@ -610,14 +621,13 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
};
OneShotBufferizationOptions updatedOptions(options);
updatedOptions.opFilter.denyOperation(analysisFilterFn);
- if (failed(
- insertTensorCopies(moduleOp, updatedOptions, state, statistics)))
+ if (failed(insertTensorCopies(rootOp, updatedOptions, state, statistics)))
return failure();
}
}
if (options.testAnalysisOnly)
return success();
- if (failed(bufferizeModuleOp(moduleOp, options, state, statistics)))
+ if (failed(bufferizeRootOp(moduleOp, options, state, statistics)))
return failure();
return success();
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsert...
[truncated]
|
@llvm/pr-subscribers-mlir-bufferization Author: Evan Liu (Evanyl) ChangesThere was a commit to change Patch is 32.04 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/148327.diff 19 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h
similarity index 64%
rename from mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
rename to mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h
index 2cf801dd1d951..32b76269e2c03 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h
@@ -1,4 +1,4 @@
-//===- OneShotModuleBufferize.h - Bufferization across Func. Boundaries ---===//
+//===- OneShotRootBufferize.h - Bufferization across Func. Boundaries ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,15 +6,15 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H
-#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H
+#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTROOTBUFFERIZE_H
+#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTROOTBUFFERIZE_H
namespace llvm {
struct LogicalResult;
} // namespace llvm
namespace mlir {
-class ModuleOp;
+class Operation;
namespace bufferization {
struct BufferizationStatistics;
@@ -22,11 +22,11 @@ class OneShotAnalysisState;
struct OneShotBufferizationOptions;
class BufferizationState;
-/// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in
+/// Analyze `rootOp` and its nested ops. Bufferization decisions are stored in
/// `state`.
llvm::LogicalResult
-analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
- BufferizationStatistics *statistics = nullptr);
+analyzeRootOp(Operation *rootOp, OneShotAnalysisState &state,
+ BufferizationStatistics *statistics = nullptr);
/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
///
@@ -38,23 +38,23 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
/// is not empty. The FuncOps it contains were not analyzed. Buffer copies
/// will be inserted only to these FuncOps.
llvm::LogicalResult
-bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options,
- BufferizationState &state,
- BufferizationStatistics *statistics = nullptr);
+bufferizeRootOp(Operation *rootOp, const OneShotBufferizationOptions &options,
+ BufferizationState &state,
+ BufferizationStatistics *statistics = nullptr);
-/// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
-void removeBufferizationAttributesInModule(ModuleOp moduleOp);
+/// Remove bufferization attributes on every FuncOp arguments in the RootOp.
+void removeBufferizationAttributesInRoot(Operation *rootOp);
-/// Run One-Shot Module Bufferization on the given module. Performs a simple
+/// Run One-Shot Root Bufferization on the given root op. Performs a simple
/// function call analysis to determine which function arguments are
/// inplaceable. Then analyzes and bufferizes FuncOps one-by-one with One-Shot
/// Bufferize.
-llvm::LogicalResult runOneShotModuleBufferize(
- ModuleOp moduleOp,
+llvm::LogicalResult runOneShotRootBufferize(
+ Operation *rootOp,
const bufferization::OneShotBufferizationOptions &options,
BufferizationState &state, BufferizationStatistics *statistics = nullptr);
} // namespace bufferization
} // namespace mlir
-#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H
+#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTROOTBUFFERIZE_H
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index db1eb20512033..5a52daf6c7698 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -10,7 +10,7 @@
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
-#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -92,8 +92,8 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
if (options.bufferizeFunctionBoundaries) {
if (!moduleOp)
return emitSilenceableError() << "expected module target";
- if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options,
- bufferizationState)))
+ if (failed(bufferization::runOneShotRootBufferize(moduleOp, options,
+ bufferizationState)))
return emitSilenceableError() << "bufferization failed";
} else {
if (failed(bufferization::runOneShotBufferize(target, options,
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 246555dc8c699..baef091eeebd1 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -12,7 +12,7 @@
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
-#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Diagnostics.h"
@@ -163,8 +163,7 @@ struct OneShotBufferizePass
BufferizationStatistics statistics;
ModuleOp moduleOp = getOperation();
if (opt.bufferizeFunctionBoundaries) {
- if (failed(
- runOneShotModuleBufferize(moduleOp, opt, state, &statistics))) {
+ if (failed(runOneShotRootBufferize(moduleOp, opt, state, &statistics))) {
signalPassFailure();
return;
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
index 7c38621be1bb5..fa310b95df4bd 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
@@ -11,7 +11,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
FuncBufferizableOpInterfaceImpl.cpp
LowerDeallocations.cpp
OneShotAnalysis.cpp
- OneShotModuleBufferize.cpp
+ OneShotRootBufferize.cpp
OwnershipBasedBufferDeallocation.cpp
TensorCopyInsertion.cpp
OptimizeAllocationLiveness.cpp
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index b7db2e847a335..ee1a9178a9d0d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -11,7 +11,7 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
-#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dominance.h"
@@ -209,7 +209,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(RewriterBase &rewriter,
OneShotAnalysisState state(op, options);
if (moduleOp) {
// Module analysis takes into account function boundaries.
- if (failed(analyzeModuleOp(moduleOp, state)))
+ if (failed(analyzeRootOp(moduleOp, state)))
return failure();
} else {
// Regular One-Shot Bufferize ignores func.func block arguments, func.call,
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotRootBufferize.cpp
similarity index 85%
rename from mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
rename to mlir/lib/Dialect/Bufferization/Transforms/OneShotRootBufferize.cpp
index d1d106220a38c..a7865050e6e38 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotRootBufferize.cpp
@@ -1,4 +1,5 @@
-//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===//
+//===- OneShotRootBufferize.cpp - Bufferization across Func. Boundaries
+//----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,12 +7,12 @@
//
//===----------------------------------------------------------------------===//
//
-// Module Bufferization is an extension of One-Shot Bufferize that
+// Root Bufferization is an extension of One-Shot Bufferize that
// bufferizes function boundaries. It provides `BufferizableOpInterface`
// implementations for FuncOp, CallOp and ReturnOp.
//
-// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`.
-// This function analyzes the given module and determines the order of analysis
+// Root Bufferization is run via `runOneShotRootBufferize(RootOp, ...)`.
+// This function analyzes the given op and determines the order of analysis
// and bufferization: Functions that are called are processed before their
// respective callers.
//
@@ -24,7 +25,7 @@
// * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
// read/written.
//
-// Module Bufferization implements the following calling convention.
+// Root Bufferization implements the following calling convention.
//
// * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may always
// be written to in-place.
@@ -57,7 +58,7 @@
// TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked
// as "not reading" and/or "not writing".
-#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -299,7 +300,7 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
llvm::IsaPred<TensorType>);
}
-/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
+/// Store all functions of the `rootOp` in `orderedFuncOps`, sorted by
/// callee-caller order (i.e., callees without callers first). Store all
/// remaining functions (i.e., the ones that call each other recursively) in
/// `remainingFuncOps`. Does not traverse nested symbol tables.
@@ -309,7 +310,7 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
/// Return `failure()` if we are unable to retrieve the called FuncOp from
/// any func::CallOp.
static LogicalResult getFuncOpsOrderedByCalls(
- ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
+ Operation *rootOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap,
SymbolTableCollection &symbolTables) {
// For each FuncOp, the set of functions called by it (i.e. the union of
@@ -317,26 +318,29 @@ static LogicalResult getFuncOpsOrderedByCalls(
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
// For each FuncOp, the number of func::CallOp it contains.
DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
-
- for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
- // Collect function calls and populate the caller map.
- numberCallOpsContainedInFuncOp[funcOp] = 0;
- WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
- func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables);
- assert(calledFunction && "could not retrieved called func::FuncOp");
- // If the called function does not have any tensors in its signature, then
- // it is not necessary to bufferize the callee before the caller.
- if (!hasTensorSignature(calledFunction))
- return WalkResult::skip();
-
- callerMap[calledFunction].insert(callOp);
- if (calledBy[calledFunction].insert(funcOp).second) {
- numberCallOpsContainedInFuncOp[funcOp]++;
+ for (mlir::Region ®ion : rootOp->getRegions()) {
+ for (mlir::Block &block : region.getBlocks()) {
+ for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
+ // Collect function calls and populate the caller map.
+ numberCallOpsContainedInFuncOp[funcOp] = 0;
+ WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
+ func::FuncOp calledFunction = getCalledFunction(callOp);
+ assert(calledFunction && "could not retrieved called func::FuncOp");
+ // If the called function does not have any tensors in its signature,
+ // then it is not necessary to bufferize the callee before the caller.
+ if (!hasTensorSignature(calledFunction))
+ return WalkResult::skip();
+
+ callerMap[calledFunction].insert(callOp);
+ if (calledBy[calledFunction].insert(funcOp).second) {
+ numberCallOpsContainedInFuncOp[funcOp]++;
+ }
+ return WalkResult::advance();
+ });
+ if (res.wasInterrupted())
+ return failure();
}
- return WalkResult::advance();
- });
- if (res.wasInterrupted())
- return failure();
+ }
}
// Iteratively remove function operations that do not call any of the
@@ -447,9 +451,9 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
}
LogicalResult
-mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
- OneShotAnalysisState &state,
- BufferizationStatistics *statistics) {
+mlir::bufferization::analyzeRootOp(Operation *rootOp,
+ OneShotAnalysisState &state,
+ BufferizationStatistics *statistics) {
assert(state.getOptions().bufferizeFunctionBoundaries &&
"expected that function boundary bufferization is activated");
FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);
@@ -465,9 +469,8 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
// A mapping of FuncOps to their callers.
FuncCallerMap callerMap;
- if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
- remainingFuncOps, callerMap,
- funcState.symbolTables)))
+ if (failed(getFuncOpsOrderedByCalls(rootOp, orderedFuncOps, remainingFuncOps,
+ callerMap, funcState.symbolTables)))
return failure();
// Analyze functions in order. Starting with functions that are not calling
@@ -511,20 +514,24 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
return success();
}
-void mlir::bufferization::removeBufferizationAttributesInModule(
- ModuleOp moduleOp) {
- for (auto op : moduleOp.getOps<func::FuncOp>()) {
- for (BlockArgument bbArg : op.getArguments())
- removeBufferizationAttributes(bbArg);
+void mlir::bufferization::removeBufferizationAttributesInRoot(
+ Operation *rootOp) {
+ for (mlir::Region ®ion : rootOp->getRegions()) {
+ for (mlir::Block &block : region.getBlocks()) {
+ for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
+ for (BlockArgument bbArg : funcOp.getArguments())
+ removeBufferizationAttributes(bbArg);
+ }
+ }
}
}
-LogicalResult mlir::bufferization::bufferizeModuleOp(
- ModuleOp moduleOp, const OneShotBufferizationOptions &options,
+LogicalResult mlir::bufferization::bufferizeRootOp(
+ Operation *rootOp, const OneShotBufferizationOptions &options,
BufferizationState &state, BufferizationStatistics *statistics) {
assert(options.bufferizeFunctionBoundaries &&
"expected that function boundary bufferization is activated");
- IRRewriter rewriter(moduleOp.getContext());
+ IRRewriter rewriter(rootOp->getContext());
// A list of non-circular functions in the order in which they are analyzed
// and bufferized.
@@ -542,9 +549,8 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
// accurate buffer types for function return values. Functions that call
// each other recursively are bufferized in an unspecified order at the end.
// We may use unnecessarily "complex" (in terms of layout map) buffer types.
- if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
- remainingFuncOps, callerMap,
- state.getSymbolTables())))
+ if (failed(getFuncOpsOrderedByCalls(rootOp, orderedFuncOps, remainingFuncOps,
+ callerMap, state.getSymbolTables())))
return failure();
llvm::append_range(orderedFuncOps, remainingFuncOps);
@@ -571,22 +577,27 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
}
// Bufferize all other ops.
- for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
- // Functions were already bufferized.
- if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
- continue;
- if (failed(bufferizeOp(&op, options, state, statistics)))
- return failure();
+ for (mlir::Region ®ion : rootOp->getRegions()) {
+ for (mlir::Block &block : region.getBlocks()) {
+ for (mlir::Operation &op :
+ llvm::make_early_inc_range(block.getOperations())) {
+ // Functions were already bufferized.
+ if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
+ continue;
+ if (failed(bufferizeOp(&op, options, state, statistics)))
+ return failure();
+ }
+ }
}
// Post-pass cleanup of function argument attributes.
- removeBufferizationAttributesInModule(moduleOp);
+ removeBufferizationAttributesInRoot(rootOp);
return success();
}
-LogicalResult mlir::bufferization::runOneShotModuleBufferize(
- ModuleOp moduleOp, const OneShotBufferizationOptions &options,
+LogicalResult mlir::bufferization::runOneShotRootBufferize(
+ Operation *rootOp, const OneShotBufferizationOptions &options,
BufferizationState &state, BufferizationStatistics *statistics) {
assert(options.bufferizeFunctionBoundaries &&
"expected that function boundary bufferization is activated");
@@ -594,7 +605,7 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
"invalid combination of bufferization flags");
if (!options.copyBeforeWrite) {
if (options.noAnalysisFuncFilter.empty()) {
- if (failed(insertTensorCopies(moduleOp, options, state, statistics)))
+ if (failed(insertTensorCopies(rootOp, options, state, statistics)))
return failure();
} else {
// FuncOps whose names are specified in options.noAnalysisFuncFilter will
@@ -610,14 +621,13 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
};
OneShotBufferizationOptions updatedOptions(options);
updatedOptions.opFilter.denyOperation(analysisFilterFn);
- if (failed(
- insertTensorCopies(moduleOp, updatedOptions, state, statistics)))
+ if (failed(insertTensorCopies(rootOp, updatedOptions, state, statistics)))
return failure();
}
}
if (options.testAnalysisOnly)
return success();
- if (failed(bufferizeModuleOp(moduleOp, options, state, statistics)))
+ if (failed(bufferizeRootOp(moduleOp, options, state, statistics)))
return failure();
return success();
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsert...
[truncated]
|
cc @christopherbate, thanks! |
e35c80c
to
2aa588e
Compare
2aa588e
to
75d70e0
Compare
You can test this locally with the following command:git-clang-format --diff HEAD~1 HEAD --extensions cpp,h -- mlir/test/lib/Dialect/Bufferization/TestOneShotRootBufferize.cpp mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp mlir/tools/mlir-opt/mlir-opt.cpp mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h mlir/lib/Dialect/Bufferization/Transforms/OneShotRootBufferize.cpp View the diff from clang-format here.diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 5fff7da99..aceeef4ef 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -116,8 +116,7 @@ public:
bufferization::BufferizationState bufferizationState;
- if (failed(bufferization::bufferizeRootOp(getOperation(),
- updatedOptions,
+ if (failed(bufferization::bufferizeRootOp(getOperation(), updatedOptions,
bufferizationState)))
return failure();
|
There was a commit to change
OneShotModuleBufferize
to no longer descend into nested symbol tables, recommending users who wish to do this should do so in a pass pipeline/custom pass. This did not support the use case of ops that weren't ModuleOps. I propose changingOneShotModuleBufferize
toOneShotRootBufferize
, and to operate on any general Operation.